{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "dcf6ea6f-b730-4ae2-accd-7e3b4277f11f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "5d42dca2-5588-4004-acd7-0a219f5ddada",
   "metadata": {},
   "outputs": [],
   "source": [
    "## 创建数据\n",
    "torch.manual_seed(100)\n",
    "x = torch.linspace(-30,30,100)\n",
    "x = (x - x.mean()) / x.std()\n",
    "epsilon = torch.randn(x.shape)\n",
    "y = 10 * x + 5 + epsilon"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "3b5a488b-2be2-475c-8050-930e1580a8ed",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.collections.PathCollection at 0x1ed82d17050>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAisAAAGdCAYAAADT1TPdAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy81sbWrAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA5K0lEQVR4nO3df3RU9Z3/8deEQgKYjEKEGRQhFbSNsQooBtEq1FCopWq7XfHXF7suR1G2B9x+RfTbQuoPxK3inqWlZ6vVPYeqPWf9tSyWIx4QC8aCAkcx/oIGoZoUCTiDCAkm9/tHetNk5s7MvZN7596ZeT7OyTmbyZ2bT4bpztvP5/0jZBiGIQAAgIAq8XsBAAAA6RCsAACAQCNYAQAAgUawAgAAAo1gBQAABBrBCgAACDSCFQAAEGgEKwAAINC+4vcC+qqzs1OffPKJysvLFQqF/F4OAACwwTAMHT58WCNGjFBJSfq9k7wPVj755BONHDnS72UAAIAs7Nu3T6eeemraa/I+WCkvL5fU9cdWVFT4vBoAAGBHPB7XyJEjuz/H08n7YMU8+qmoqCBYAQAgz9hJ4SDBFgAABBrBCgAACDSCFQAAEGgEKwAAINAIVgAAQKARrAAAgEAjWAEAAIFGsAIAAAIt75vCAQAAb3R0GtrSdFD7Dx/TsPIyTawaon4luZ/DR7ACAACSrN3ZrPrVjWqOHet+LBou0+KZ1ZpeE83pWjgGAgAAvazd2ay5q7b1ClQkqSV2THNXbdPanc05XQ/BCgAAUEenoYbdrXpu219013M7ZVhcYz5Wv7pRHZ1WV3iDYyAAAPKYG3klVkc+qRiSmmPHtKXpoCadPjTLVTtDsAIAQJ6ym1eSLqAxj3yc7pPsP5w5sHELwQoAAHkoVZDRHDumW1Zt002TR+uy6ogOHWnXPWusA5q66ojqVzc6DlQkaVh5WZ/W7wTBCgAAeaaj08gYZDy2eY8e27zH8mdmouz8y8baOvrpKSQpEu7anckVEmwBAMgzW5oOOg4yejKDnMdTBDOpmJkwi2dW57TfCjsrAAAEgJNEWTfyRQxJnx097ug5EZ/6rBCsAADgM6cN2NzMFzlxYH/Fjh5PeaQ0ZHB//fS7ZylS4V8HW46BAADwUTYN2CZWDVE0XCY3woYfTa6SpKR7hf72df9VZ+uqcado0ulDfQlUJIIVAAB8ky5RNl0Dtn4lIS2eWS0pOciwK6Su3Zt5U8do5fXjFQn33q2JhMu08vrxOT/yscIxEAAAPsmUKJuuAdv0mqhWXj/edjO3nhITZafXRFVXHQnE0EIrBCsAAPjEbqJsqut6BhnrGlv02817FJIy9k2xSpTtVxLKWUdapzw9Blq6dKnOP/98lZeXa9iwYbryyiv1/vvv97rGMAwtWbJEI0aM0MCBA3XppZfqnXfe8XJZAAAEgt1E2XTXmUHGz2aepV9bHOdEw2X61bXj9NScWv37rHP11JxabVo4NRDHO3Z5urOyceNG3XbbbTr//PP15Zdf6u6779a0adPU2NiowYMHS5IefPBBPfzww3riiSd0xhln6N5771VdXZ3ef/99lZeXe7k8AAB8ZSbKtsSOWe6GOG3AFvTjnGyFDMPI2djETz/9VMOGDdPGjRv1zW9+U4ZhaMSIEZo/f74WLlwoSWpra9Pw4cO1bNky3XzzzRnvGY/HFQ6HFYvFVFFR4fWfAACAq8xqIKn38Y0ZXgQlydVtTj6/c1oNFIvFJElDhnRFiE1NTWppadG0adO6ryktLdUll1yi1157zfIebW1tisfjvb4AAMhXZqJskKtx/JazBFvDMHT77bfroosuUk1NjSSppaVFkjR8+PBe1w4fPlwfffSR5X2WLl2q+vp6bxcLAIDHEjvWbvy/U/TmR4cK6vjGLTkLVubNm6e33npLmzZtSvpZKNT7H8MwjKTHTIsWLdLtt9/e/X08HtfIkSPdXSwAAB5K17H2inNP8XFlwZST
Y6B/+Zd/0f/8z/9ow4YNOvXUU7sfj0Qikv6+w2Lav39/0m6LqbS0VBUVFb2+AADIF9l0rC12ngYrhmFo3rx5evbZZ7V+/XpVVVX1+nlVVZUikYjWrVvX/Vh7e7s2btyoCy+80MulAQCQc9l2rPViHQ27W/XCjo/VsLvV89/XV54eA91222168skn9cILL6i8vLx7ByUcDmvgwIEKhUKaP3++7r//fo0dO1Zjx47V/fffr0GDBunaa6/1cmkAAORcXzrWSs4mM6fidGhiEHgarKxcuVKSdOmll/Z6/PHHH9eNN94oSbrjjjt09OhR3XrrrTp06JAuuOACvfTSS/RYAQAUnL50rHUjyDCPoBL3UcwjqKBWH+W0z4oX6LMCAMgXDbtbdc1vXs943VNzanvtrKQKMpz0YunoNHTRsvUpd3bMBnSbFk7NSRVSYPusAABQzMyOtalCAXMScs+OtW7luTg5ggoaghUAAHKkX0lIi2dWS1JSwJI4CdnkVpDR16GJfiJYAQDAJjeqaJx2rHUryHBjaKJfctYUDgCAfOZmFY3dgYMdnYYOHG6zdc9MQYbbQxNziWAFAIAM+lpFk6rk2Ko8uefvTAyOrNgNMswjqLmrtikk66GJiUdQQUGwAgBAGpkSXEPqSnCtq45YftBnsyOTKjhK5DTIMI+gEtcTKeY+KwAA5Lu+NHLLZkcmXXCUKJsgw+4RVJAQrAAAkEa2Ca7Z7shkCo5MP73867pxclVWQUamI6igoRoIAIA0nFbRmBVDy9e9n1XJsd3gqLK8NNC7IW5iZwUAgDScVNHYTYrtKTE4yecSY6+wswIAQBp2G7mta2zR3FXbHAUqUnLQkU2X20JHsAIAQAaZGrnVVUdsJ8WaUgUd2XS5LXQcAwEAYEO6KpqG3a2OdlQyBR35WmLsFYIVAACUunFbT6mqaJzO07ETdORjibFXCFYAAEWvr6307Sa7zpsyRpPHVNoOOvKtxNgr5KwAAIqa2bgt8RjHbNy2dmdzxnvYTYpdUHeGJp0+tCh3R/qCYAUAULQyNW6Tuhq3ZZquTFKstwhWAAAFx2zM9sKOj9WwuzVlsOGklX4mmSqGii0p1k3krAAACoqT/JNsW+mnQlKsNwhWAAAFw+ngQDe6xVpVEZEU6y6CFQBAQchmcKCTVvpW+lpFBHvIWQEAFIRs8k8yJcYakmadP1L/+9YnSbkvblQRwR52VgAABcFp/ol5fNP2ZafmX3aGntqyVy3xv98jPKi/JGn5yx92P2bumqRrr59qFwfZI1gBABQEJ/knVsc3kYpSLbhsrEZXDtaeA1/okZc/SJn7Mv+ysbZ3cchf6TuOgQAABcFuY7ZDR9otj2/+Gm/TIy9/qP4lIT29dW/a3iuPb95ja01O2/DDGsEKAKAg2GnM9tPLv6571qRvAvf/XtiZcdfks6PHba3J7m4P0iNYAQAUjEyN2U4aXJoxEDl4xF4gcuLA/hl3cVJVEcEZclYAAAUlXWO2F3Z87Nrv+dHkKj3y8gfdVUMm2uu7j2AFAFBwUk0rtnssM2TwAB060p6298q8qWN0ZuSE5ERd+qy4jmAFAFA07DaB++nl1brtyW0Zd01or58b5KwAAIqG3enI3/mG/aGE5i7OFeeeokmnDyVQ8UDIMIz0c68DLh6PKxwOKxaLqaKiwu/lAADygN02+VZzfwhG3OHk85tgBQBQlAhE/OXk85ucFQBAXss26EiVhIvgIVgBAOQtph4XBxJsAQB5ycnU445OQw27W/XCjo+Tpicj+NhZAQDkDfPIpyV2VPesedfW1ON1jS3svuQ5ghUAQF6wOvJJxZx6vGL9rrTTkxPLkBFMHAMBAAIv1ZFPJo9vbko7tLB+dSNHQnnA02Dl1Vdf1cyZMzVixAiFQiE9//zzvX5+4403KhQK9fqqra31ckkAgDzT0WmofrX1pORM0k1HNndftjQdzHptyA1Pg5UjR47onHPO0YoVK1JeM336dDU3N3d/vfji
i14uCQCQA24mtG5pOuh4RyWkrqnIduw/7OzeyD1Pc1ZmzJihGTNmpL2mtLRUkUjEy2UAAHLI7XJip8GE2WHlR5NHa/nLH2a83u5wQ/jH95yVV155RcOGDdMZZ5yhOXPmaP/+/Wmvb2trUzwe7/UFAAgGJ+XEdjkNJsz5PfOmjlU0XJY0A8gUUlcQNbFqiOM1Ibd8DVZmzJih3/3ud1q/fr0eeughbd26VVOnTlVbW1vK5yxdulThcLj7a+TIkTlcMQAglXS5JX1JaDUnJafrSTtkcH8tv/pcPTWnVpsWTtX0mqjtoYW02A8+X4OVq6++Wpdffrlqamo0c+ZM/eEPf9AHH3ygNWvWpHzOokWLFIvFur/27duXwxUDAFLJlFuSbUJrpqAjJOn+q87WVeOSpx5Pr7E/PRnBFag+K9FoVKNGjdKHH6Y+YywtLVVpaWkOVwUAsMNubkk2Ca1m0JGYCxOxkQszvSaquuoIQwvzWKCCldbWVu3bt0/RKJEuAARVqsGBdnNLUl2XaSBhX4IOhhbmN0+Dlc8//1y7du3q/r6pqUk7duzQkCFDNGTIEC1ZskQ/+MEPFI1GtWfPHt11112qrKzUVVdd5eWyAABZSlfpU1cdUTRcppbYsZQ9UU4c2F+dhqGOTqNXkGG3goigoziFDMPwrHXfK6+8oilTpiQ9Pnv2bK1cuVJXXnmltm/frs8++0zRaFRTpkzRPffc4yhpNh6PKxwOKxaLqaKiws3lAwB6MCt9Ej80zJBj5fXjJUlzV22TpLRN3KLhMv308q/rpMGlWtfYot9u3pN0Tc/7kltSeJx8fnsarOQCwQoAeK+j09BFy9anTKANqSt/ZNPCqZaDA7PV877kmBQWJ5/fvvdZAQAEn5NKn+k1UW1aOFW/u+kC211k7dwXxYtgBQCQkdNKn34lIZWUhNLO5vHi96MwEawAADLKptLHzQCDlvjFjWAFAJBRpi6yVq3r3QgwaIkPiWAFAGBDNq3r7bTJT4eW+DARrAAAbHHauj5dgGMHLfFhonQZAOBIpk6ziawavqVz0+TRuqw6Qkv8AkefFQBAoCQGOIeOtOueNZk71qJwOfn8DtRsIABAYbJqk//tGoYLwh6CFQCAL5jzA7tIsAUAAIFGsAIAAAKNYAUAAAQaOSsAgMByWiaNwkSwAgAIJKv+LJQ3FyeOgQCgAHV0GmrY3aoXdnysht2t6ujMr5Zaa3c2a+6qbUmN5FpixzR31Tat3dns08rgB3ZWAKDAWO1IRCpKdc3E0zS6cnDgj1M6Og3Vr26UVXhlqKt1f/3qRtVVRwL7N8BdBCsAUEDMHYnED/qWeJuWv/xh9/dBPk7Z0nQwbWt+Q1Jz7Ji2NB2kT0uR4BgIAApEuh2JREE+Ttl/2N4MIbvXIf8RrABAgci0I9GTGdDUr24MXD7LsPKyzBc5uA75j2AFAAqE052GnscpQTKxaoii4TKlykYJqesYa2LVkFwuCz4iWAGAApHtTkPQjlP6lYS0eGa1JCUFLOb3i2dWk1xbRAhWACDAnJQgZ9qRSKWvxylelElPr4lq5fXjFQn3XlskXKaV148PZGIwvEM1EAAElNOmaOaOxNxV2xSSMibahtT14d+X4xQvG7dNr4mqrjpCB1soZBhGsDKrHIrH4wqHw4rFYqqoqPB7OQDgilQlyObHdLrdBasAIpGd+3i5RsDJ5zfBCgAETEenoYuWrU8ZbJg7IpsWTk25y9Bzps6eA1/oqS171RJ3b/fDjTWiuDn5/OYYCAACxo2maP1KQr1+Nm/qGFePU2jchlwiWAGAgPGiKVpi8GLFyYRjGrchlwhWACBgctUUrS9HRTRuQy4RrABAwJglyC2xY5YVPV5V8SQyW/JbJcrmYo2AiT4rABAwXjdFM6t4MrXmN/72dddzb+u57b17qNC4DblENRAABJQXPUwyVfFkkvj7veyzgsJG6TIAFAgnSa92NOxu1TW/eT3r51v1UHF7jSgOlC4DQIGw
U8XjRF+rcwx1BSz1qxtVVx1Rv5KQ62sEEpGzAgBFxI3qnKBOa0bhIlgBgCKS7bBDK/RQQa4QrABAEUlXxeMUPVSQKwQrAFBkptdEtfL68YqEewcbkYpSLbhsrJb/4zkaMnhAymAmpK6KH3qoIFdIsAWAIjS9Jqq66kjKKp6BA/pp7qptCkm9mr7RQwV+oHQZAGCJHirwkpPPb0+PgV599VXNnDlTI0aMUCgU0vPPP9/r54ZhaMmSJRoxYoQGDhyoSy+9VO+8846XSwIA2DS9JqpNC6fqqTm1+vdZ5+qpObXatHAqgQpyztNg5ciRIzrnnHO0YsUKy58/+OCDevjhh7VixQpt3bpVkUhEdXV1Onz4sJfLAgDYZPZQueLcUzTp9KEc/cAXnuaszJgxQzNmzLD8mWEYeuSRR3T33Xfr+9//viTpv/7rvzR8+HA9+eSTuvnmm71cGgB4KpddXekgi0LnW4JtU1OTWlpaNG3atO7HSktLdckll+i1115LGay0tbWpra2t+/t4PO75WgHAiVzmepBXgmLgW+lyS0uLJGn48OG9Hh8+fHj3z6wsXbpU4XC4+2vkyJGerhMAnEg10bgldkxzV23T2p3Nefm7AD/53mclFOq9VWkYRtJjPS1atEixWKz7a9++fV4vEQBs6eg0VL+6UVYlluZj9asb1dHZtyLMjk5Dmz88oDufedvz3wUEgW/HQJFIRFLXDks0+vetyv379yfttvRUWlqq0tJSz9cHAE5taTqYtMvRU8+ZOtkO/rM69vHqdwFB4dvOSlVVlSKRiNatW9f9WHt7uzZu3KgLL7zQr2UBQNbszsrJdqZOqmMfL34XECSe7qx8/vnn2rVrV/f3TU1N2rFjh4YMGaLTTjtN8+fP1/3336+xY8dq7Nixuv/++zVo0CBde+21Xi4LADxhd1ZOuutSVfakO2JyY01AkHkarLzxxhuaMmVK9/e33367JGn27Nl64okndMcdd+jo0aO69dZbdejQIV1wwQV66aWXVF5e7uWyAMAT5kTjltgxy6AiJCmSZqZOusqe8MABjnZUMv0uIJ/Qbh8AXGQe1UjWM3VWXj/esqTYfF7i/0M2n/dPk0frsc17bK0h0+8CgiAw7fYBoNiknGgcLksZPNipInpux8e215DudwH5iKnLAOCyTBONTWZ+yuZdn2asIjp45LiGDB6gQ0faU+atnDiwv3553XjVfpW2+CgsBCsAkEa2rezNmTqp2C1B7unKc0fo8c17FJL1EdMDPzhbk8dU2r4fkC8IVgAgBa9a2afKT8mkrjqiiVVDktYUob0+ChzBCgBYSBVQmK3ss80JyaYEuWdlT7+SkK0jJqCQEKwAQIJMCa8hdbWyr6uOOA4SMnW5TWTeffHM6u7flemICSg0VAMBQAInbfPt6ug01LC7VX9wOFyQyh6AnRUASOJ22/xskmnnTRmjyWMqOeIBRLACoIilqvRxo22+yWkyrZmfsqDuDIIU4G8IVgAUpXSVPnXVkT61zTc5Taa1yk8BQM4KgCKUanqxWemzrrFFi2dWS/p7AGFyElA4TaYlPwWwxs4KgKJit9Jn08KpWnn9+D71NLGb0/J/Jo3SjJoo+SlACgQrAIqKk0ofu23zU7Gb+zKjJkopMpAGwQqAouK00qcvPU0mVg1xJfcFKHbkrAAoKm5W+mTSryTkSu4LUOwIVgAUFXO3I1V4EFJXVVC63Q6zwdsLOz5Ww+5WdXSmrveZXhPVyuvHKxLuHfyQTAvYxzEQgKJi7nbMXbUt5fTidLsd2Qw37GvuC1DsQoZhOB38GSjxeFzhcFixWEwVFRV+LwdAnsgm6EjV4M0MOdgpAexz8vnNzgqAouR0t8PL4YYA0iNYAVC0nFT62C15fmJzk26cXEXAAriIBFsAsMFuyfM9a97VRcvWa63D6coAUiNYAQAbnJQym237CVgAdxCsAIANmUqeezLzWupXN6YtawZgD8EKANiQrsGblZ5t+wH0DcEK
ANiUqsFbOnZzXQCkRjUQgLzT0Wn41mDNLHl+YnOT7lnzbsbr3WjbDxQ7ghUAeSWbZm5u61cS0o2Tq/TopiaGFAI5wDEQgLxhdpBN7HfiR/UNQwqB3CFYARBo5tDA57b9RXc9tzNlB1kp99U3DCkEcoNjIACBZXXkk0rP6hu7XWndwJBCwHsEKwACKdXQwEzM6ptcJuE6adsPwDmCFQCBk25oYCbDyssCkYQLwD3krAAInExDA62E1BWQHDrSHpgkXADuIFgBEDjZNFIzJE0/a7j+3wvBSsIF0HcEKwACx2kjNTMV5fHXPtLBI+0pr6MFPpCfCFYAOGaWE7+w42M17G51fafCztDAIYP760eTR0uSnP76P+xs9mTdALwRMgwjr//XGo/HFQ6HFYvFVFFR4fdygIKXq+RVsxpIUq9jHTOA+eW143TPmncd57b0ZLVuP1v5A8XEyec3wQoA21KVE5sf5W43QksXGIUHDtA1v3m9T/dPXDdVREDuOPn8pnQZgC3pyokNdX3w169uVF11xLWdiHQN117Y8XGf799z3Z2d0m1PJgdiZhURHWkB//ies7JkyRKFQqFeX5FIxO9lAUiQqZzYq+RVs+HaFeeeokmnD+0OhNyaZmyumyoiILh8D1Yk6ayzzlJzc3P319tvv+33kgAksFtOnE3ZcTbsJuHeUDvK1v2oIgKCKxDByle+8hVFIpHur5NPPtnvJQFIYHcnw60dj0wyTT0OSbr/qrP1nbPdO7rJVSAGoLdABCsffvihRowYoaqqKs2aNUt//vOfU17b1tameDze6wuA9zLtZJgdZCdWDcnZmuxMPbaz7iGD+9v6fbkKxAD05ns10B/+8Ad98cUXOuOMM/TXv/5V9957r9577z298847Gjo0eTDYkiVLVF9fn/Q41UCA9zKVE/uVhJqp3NhuGXRL7Jhl3kpIXQHQpoVTKWMGXJLXpctHjhzR6aefrjvuuEO333570s/b2trU1tbW/X08HtfIkSMJVoAcydfy3kzrDmogBhSqvA5WJKmurk5jxozRypUrM15LnxUg9/K1cZqdHZh8DMSAfJTXfVba2tr07rvv6uKLL/Z7KQBSMMuJ802mdafr6wLAP74HKz/5yU80c+ZMnXbaadq/f7/uvfdexeNxzZ492++lAShC+RqIAYXM92DlL3/5i6655hodOHBAJ598smpra/X6669r1Ch7vREAAEBh8z1Yefrpp/1eAgAPWOWHSMp4xJKv+TAAvON7sAKg8Fglqp44qKuXyWdfHO9+LDF5lQRXAFYCWQ3kBNVAQLCkmsxspWdZsKScTnQG4K+8rgYCkL/STWa2Yl638L/fUr9+JTmd6AwgfwSi3T6AwpBpMnMqsWNfMkgQQEoEKwBc4/WgPwYJAsWJYyAAaTmpzvF60B+DBIHiRLACICWn1TnmhONUAwGzZQ4SzOVEZwDBwTEQAEtmVU9iDkpL7JjmrtqmtTubez1u7sDMqIl0J8W6wbzP4pnVJNcCRYqdFQBJ0lX1WFXnWO3AhEJSz8YIVn1W7IjQZwUoegQrAJJkqurpWZ0TO9pu2R+l828P3DR5tC6rjnQf4by+u1W3PblNnx1NHbQMGdxfP/3uWYpU0MEWAMdAACzYrbppiR1N21clJOnFnS3dAUe/kpAmj63UAz84WyElHxWZj91/1dm6atwpmnT6UAIVAAQrAJLZrbo5eKTd9g5MT9Nrolp5/XhFwr1/TyRcRqdaAEk4BgKQJFNVj1mdM+SEUlv3s9qpmV4TVV11hKGFADIiWAGQpF9JSItnVmvuqm0KSb0Clp7VOeGBA2zdL9VOTb+SkCadPrRPawVQ+DgGAmDJzlGNuQOTai8kpK6+LPRHAdAX7KwASCnTUY3dHRiOdgD0BcEKUMSsWulLSnos3VGNuQOT2GeF/igA3EKwAhQpq0ZuVo3b0rXXN5EsC8BLIcMw3BzhkXPxeFzhcFixWEwVFRV+LwfIC2YrfTv/4zfDDUqK
AbjJyec3CbZAkUnXSt+KeV396kZ1dOb1f9sAyFMEK0CRydRK30qq5m4AkAsEK0CRsdtK3+3nAkC2CFaAImO3lb7bzwWAbFENBBQIqzLkxGqcjk5DnZ2GThzYP+3U40Rme32auwHwA8EKUACsypATS46trrGD5m4A/EawArjEzs6GF1KVIbfEjmnuqm1aef14SbJVqmzVZ4XmbgD8RrACuMDOzoYX0pUhG+raFVnyP+9ICqUNVE4c2F+/vG68ar/a1amW5m4AgoRgBegjOzsbXgUsmcqQDUkt8baM9/ns6HGVhELdQQmTkAEECdVAQB9k2tmQvG2m5mYpMWXJAIKKYAXoAzs7G142U3OzlJiyZABBRbAC9IHd3Qivdi0mVg1RNFymVBklIUmRilJFKtJfE6UsGUCAEawAfWB3N8KrXYt+JSEtnlktSUnBiPn9ku+dpSXfS38NZckAgoxgBegDOzsbXu9aTK+JauX14xUJ9w6IIuGy7uReO9cAQFCFDMPI6zGqTkZMA14wq4Ek9Uq0NQOYXAUDdjvYUpYMIAicfH4TrAAu8KvPCgDkKyef3/RZAVwwvSaquuoIOxsA4AGCFcAl/UpCaZupsfsCANkhwRbIATOvJbEni9nldu3OZp9WBgDBR7ACeMzvLrcAkO8IVgCP+d3lFgDyXSCClV/96leqqqpSWVmZJkyYoD/+8Y9+Lwlwjd9dbgEg3/kerPz+97/X/Pnzdffdd2v79u26+OKLNWPGDO3du9fvpQGu8LvLLQDkO9+DlYcfflg33XST/vmf/1lf//rX9cgjj2jkyJFauXKl30sDXBGELrcAkM98DVba29v15ptvatq0ab0enzZtml577TXL57S1tSkej/f6AoIs0/weQ9Ks80fqf9/6RA27W9X+ZacadrfqhR0fq2F3K4m3AIqer31WDhw4oI6ODg0fPrzX48OHD1dLS4vlc5YuXar6+vpcLA9wjTmbJ7HPSnhQf0nS8pc/7H6sJCT1jE/oxQKg2Pl+DCRJoVDv/940DCPpMdOiRYsUi8W6v/bt25eLJQJ9Nr0mqk0Lp+qpObX691nnasFlZyj2xXF99sXxXtclbqTQiwVAsfN1Z6WyslL9+vVL2kXZv39/0m6LqbS0VKWlpblYHuA6s8ttR6ehi5att+y9ksi85q7n3tbR452KVNhr009rfwCFwtdgZcCAAZowYYLWrVunq666qvvxdevW6YorrvBxZYC3MvVesXLwyHEt+P0OSZmPhmjtD6CQ+H4MdPvtt+vRRx/Vb3/7W7377rtasGCB9u7dq1tuucXvpQGe6WtPlebYMd2yapvuWf1OUhIurf0BFBrfBxleffXVam1t1c9//nM1NzerpqZGL774okaNGuX30gDPuNVT5bHNe/TY5j3duyZ11ZG0rf1D6mrtX1cd4UgIQN4IGYaR13WR8Xhc4XBYsVhMFRUVfi8HsMXMWWmJHbOVt5KJGXbMv2xsr8qiVJ6aU5t2QjQAeM3J57fvx0BAMUrXeyUbZsDz+OY9tq6ntT+AfEKwAvjE7L0SCfc+Esr2dMaQ9NnR4xmvk2jtDyC/+J6zAhSz6TVR1VVHepUYTxh1kt786JBaYkd1z5p3dehIu6OjohMH9lfs6HHL54QkRWjtDyDPEKwAPjN7r/Rkfj9wQD/NXbWtuy2/HT+aXKVHXv4g6Tnmhs3imdUk1wLIKxwDAQGW6qjIijkQcd7UMZbPiYTLtPL68fRZAZB3qAYC8oDZjXZdY4t+u3lPyl2TnsEIHWwBBJmTz2+OgYAEQfyQN4+KJp0+VBOrhiR1p41YdKe1Ol4CgHxEsAL0kA9t6q2ScoMQUAGAVzgGAv7GbFOf+D8IjlgAwH0cAwEOdXQattvUr2tsCfzuCwAUEqqBAGWegmyoa3jgivW7GBIIADlGsALIfvv5xzc3pdx9kbp2X3pOQAYA9B3BCiD77efTtbM3d1+2NB10aVUAAIlgBZAkTawaomi4LOVQwZC6
2tjbwZBAAHAXwQoKRkenoYbdrXphx8dq2N3q6Dgm3RRk8/sfTR5t614MCQQAd1ENhILQl/4oZhly25edmn/ZGXpqy161xJMbrtVVR/T01n1qiR1jSCAA5BDBCvJeqv4oZoVOunk4VkFOpKJUCy4bq9GVg5N6qCyeWW05WNBqSCC9WADAHTSFQ17r6DR00bL1KcuOzd2OTQunJgUKTprAJT4v0y5OPnTCBQA/0RQORcNuf5QtTQd7zclx0gQuMcjJ1O6+Lzs9AIBkBCvIa3YrbxKvyzbIMaUaEtiXIAgAYI1qIOQ1u5U3iddlG+Rk4iQIAgDYQ7CCvGanP0rUokIn2yAnE6+CIAAoZgQryGt2+qP0rNAxZRvkZOJVEAQAxYxgBXlvek1UK68fr0i4dwAQCZelTGbNNsjJxKsgCACKGaXLKBjZ9DXxosTYrAaSrHuxUA0EAM4+vwlWUPS8aN5GnxUASI9gBUghl11l6WALAKnRFA6w4Mb8ICeBR6peLAAAZwhWUNDMIGNdY4t+u3lP0s+znR/EkQ4A5A7HQMhLdnY6rIIMK17MDwIApMcxEAqa3UGCVkGGFS/mBwEA3EOfFeQVMwhJ3C0xj3PW7mxOG2Sk05f5QQAA77Czgrxhd6ejvKx/xqMfK7maHwQAcIZgBXnD7k5Hw+5WR/c1c1ZyNT8IAOAMwQpc43VfEfs7GPYPgOzMD2qJHbO8Y6ogBwDgLoIVuCIX5b12dzAmfbVSz2z7OGWQ0VMkzRrN+UFzV21TSNat87OZHwQAcIYEW/SZnaRXN9gdElh7+tCUQwpNN00erafm1GrTwqlpg6lshiQCANxFnxX0SUenoYuWrU+ZS5Kuh0k2nAwJdHO3h9b5AOAu+qwgZ5yU97rRet7c6UgMQqyOc6bXRFVXHXElyKB1PgD4x9dgZfTo0froo496PbZw4UI98MADPq0ITvlR3uskCCHIAID85/vOys9//nPNmTOn+/sTTjjBx9XAKTfKexkSCABIx/dgpby8XJFIxO9lIEt9Le9lSCAAIBPfq4GWLVumoUOH6txzz9V9992n9vb2tNe3tbUpHo/3+oJ/zPJeKbnyJlN5b6oqoubYMd2yapvuWf2OGna3qqMzr3PAAQB95Gs10PLlyzV+/HiddNJJ2rJlixYtWqQrrrhCjz76aMrnLFmyRPX19UmPUw3kL6c7JJmqiHpipwUACo+TaiDXg5VUwURPW7du1XnnnZf0+DPPPKN/+Id/0IEDBzR0qHU+Qltbm9ra2rq/j8fjGjlyJMFKADjJPWnY3aprfvO6rftalSUDAPKbr6XL8+bN06xZs9JeM3r0aMvHa2trJUm7du1KGayUlpaqtLS0T2uEN5wkvTqpDuo5pLCuOkJ/EwAoMq4HK5WVlaqsrMzqudu3b5ckRaP813Ohczr8z+zXsnzdB5o8ppKmbABQRHyrBmpoaNDrr7+uKVOmKBwOa+vWrVqwYIG+973v6bTTTvNrWXBZqqOhTFVEqazYsEsrNuwijwUAiohvCbbbtm3Trbfeqvfee09tbW0aNWqUZs2apTvuuEODBg2yfR/a7QdXpqTbVK3z7SCPBQDym68JtrlGsBJMZiCS+OZKDDKsAhq73J47BADIHSef3773WUHh6eg0VL+60XK3xHysfnWjOjoNTa+JatPCqXpqTq3+afJoSaknJVvdy5w7BAAoXAQrcJ2T4YbS36uIfjbzLP36+vGKhJ0l37o5dwgAEDy+t9tH4bEbPGze9WlSVU/PIYWbd32qFRt2Z7yP08oiAEB+YWcFrrMbPKzYsFsXLVuvtTubez1u7rQsqDtT0XBZymOhkLoSdlPNHQIAFAaCFbjOLEu2k3vSEjumuau2JQUsUt/mDgEACgfBClyXLshIlJhwm2h6TVQrLfJYIuEyypYBoEhQugzPOC1LfmpObcp2/U7mDgEAgs/X2UDIf24FBmay7PJ1H2jFhl0Zr0+XmOtk7hAAoLAQrKCXTF1nnepX
EtLkMZW2ghWqegAAVghW0C1V19mW2DHdsmqbFlw2VqMrBzvebck0B8jsREtVDwDACsEKJNnrOrv85Q+7H3Oy22Im3M5dtU0h9Z4DRFUPACATqoEgKXPX2UTpSo6tUNUDAMgWOyuQ5LxlvaGuXZH61Y2qq47Y2hXp2Z2Wqh4AgF0EK5CUXXJrzxk/dit1qOoBADjFMRAkOes6m4hBggAALxGsQJKzrrOJKDkGAHiJYCWPdXQaatjdqhd2fKyG3a2W7eqdSJUEmwqDBAEAuUDOSp5yu3mbKTEJds+BL/TIyx9IouQYAOAPgpU8lK5529xV2/pcCpyYBHtm5ISkwCjiQmAEAIAdBCt5xk7ztjufeVvlZf1V+9Whrux6UHIMAPATwUqesdO87bOjx3Xdo39y5VjIRMkxAMAvJNjmGSdlwk67zAIAEEQEK3nGSZmweSxUv7qxz5VCAAD4hWAlzzht3tazyywAAPmIYCXPZNu8jS6zAIB8RbCSh5w2b5PoMgsAyF8EK3lqek1UmxZO1e9uukAnDuyf8jq6zAIA8h3BSp6waq3frySkyWMr9cAPzlZIycdCdJkFABQC+qzkgUyt9c1jIbrMAgAKUcgwjLyuaY3H4wqHw4rFYqqoqPB7Oa5L1Vrf3Cfp2Vq/o9OgyywAIC84+fxmZyXAMrXWD6mrh0pddUT9SkJ0mQUAFCRyVgIsU2t9t3qoWOXDAAAQFOysBJjd3ih96aGSKR8GAAC/EawEVEenoQOH22xdm66HSmIey4RRJ+nNjw5p/+Fj2nPgCz3y8gdJx0zmTKGe+TAAAPiFYCWArHY7rITUVfGTqoeK1X1KQlKmUx6rfBgAAPxCzkrAmNU/dgIVKXUPlVT3sZuOwkwhAEBQsLMSIOmqfxKl66Hi5D6ZpMuHoVQaAJALBCsBkqn6x/TTy7+uGydXpQwM7N7HjlT5MCTmAgByhWOgHEtXJmy3qqeyvDTtDoYbE5bTzRRKdcRkJuau3dnc598PAICJnZUcyrQbYXcycqbr+jphOV0+jNNGdQAA9JWnOyv33XefLrzwQg0aNEgnnnii5TV79+7VzJkzNXjwYFVWVurHP/6x2tvbvVyWL+zsRkysGqJouCxpIKHJ7gTlTPfJJBIuS1m2nKtGdQAAmDzdWWlvb9cPf/hDTZo0SY899ljSzzs6OnT55Zfr5JNP1qZNm9Ta2qrZs2fLMAz9x3/8h5dLyyknuxGLZ1Zr7qptCv3tZyYnE5T7lYRS3ieR+fMFl43V6MrBGRNlc9GoDgCAnjwNVurr6yVJTzzxhOXPX3rpJTU2Nmrfvn0aMWKEJOmhhx7SjTfeqPvuu69gBhM62Y1wa4Jyqvsk9llxel+3jqoAALDL15yVhoYG1dTUdAcqkvTtb39bbW1tevPNNzVlypSk57S1tamt7e+dXePxeE7W2pcyXae7EdNroqqrjvS5LNjqPj072GZzX/OIqSV2zHLHJlOjOgAAnPI1WGlpadHw4cN7PXbSSSdpwIABamlpsXzO0qVLu3dscqWvZbrZ7Ea4NUHZ6j59uW+6IyYnR1UAANjlOMF2yZIlCoVCab/eeOMN2/cLhZI/1AzDsHxckhYtWqRYLNb9tW/fPqd/giNulOm6lTgbFOYRUyTcOwhLl5gLAEC2HO+szJs3T7NmzUp7zejRo23dKxKJ6E9/+lOvxw4dOqTjx48n7biYSktLVVpaauv+feVWmW6m3QhD0qzzR+p/3/okbzrBunVUBQBAJo6DlcrKSlVWVrryyydNmqT77rtPzc3Nika7/mv8pZdeUmlpqSZMmODK7+gLu4mxy9d9oMljKtN+WKdKeA0P6i9JWv7yh92P5UsnWLeOqgAASMfTnJW9e/fq4MGD2rt3rzo6OrRjxw5J0pgxY3TCCSdo2rRpqq6u1g033KB/+7d/08GDB/WTn/xEc+bMCUQlkN3E2BUbdmnFhl0Zg4zE3Yg9B77QIy9/kLRzYx4xcaQCAIDHTeF+9rOf
ady4cVq8eLE+//xzjRs3TuPGjevOaenXr5/WrFmjsrIyTZ48Wf/4j/+oK6+8Ur/4xS+8XJZtTstvrfJYEtvrS10Jrt/9xgg9vXVvyiMmqeuIqcPumGQAAApUyDCMvP40jMfjCofDisViru/GdHQaumjZ+pRluqmcOLC/fnndeMW+OK571lhXEYUHDtA1v3k9472emlPLUQsAoOA4+fxmkGEaZmKsJEet6z87elzXPfon3fpk6iqilxutS7MT0QkWAFDsCFYySFWmmy1zh+a5HR/bup5OsACAYsfUZRt6JsZu3vWpVmzY3af7GZIOHjmuIYMH6NCRdjrBAgCQBjsrNpllugvqzuzTROOerjy3a8xA4r3oBAsAwN8RrDiUbR6LlbrqCJ1gAQDIgGOgLKRq8GZXzyOefiUhOsECAJAGwUqWzDyW13e36rYnt+mzo8dtPc/qiIdOsAAApMYxUAqJzdysmrP1Kwlp8thKPfCDsxWSvWMhjngAAHCGnRULa3c2Jx3xpGuln+pYKBou008v/7pOGlzKEQ8AAFmig22CtTubNXfVtqRyYjO8SLcr0tFpkHsCAIANTj6/2VnpoaPTUP3qxpTzekLqmtdTVx2xDELIPQEAwH3krPSwpelg2uoeQ1Jz7Ji2NB3M3aIAAChyBCs92J3Dw7weAAByh2ClB7tzeJjXAwBA7hCs9DCxakjaVvohdVX4MK8HAIDcIVjpIV0rfeb1AADgD4KVBGbPFOb1AAAQDJQuWzBb6dMzBQAA/xGspEDPFAAAgoFjIAAAEGgEKwAAINAIVgAAQKARrAAAgEAjWAEAAIFGsAIAAAKNYAUAAAQawQoAAAg0ghUAABBoed/B1jAMSVI8Hvd5JQAAwC7zc9v8HE8n74OVw4cPS5JGjhzp80oAAIBThw8fVjgcTntNyLAT0gRYZ2enPvnkE5WXlysUyt9Bg/F4XCNHjtS+fftUUVHh93J8w+vQhdehC69DF16HLrwOXQrldTAMQ4cPH9aIESNUUpI+KyXvd1ZKSkp06qmn+r0M11RUVOT1m88tvA5deB268Dp04XXowuvQpRBeh0w7KiYSbAEAQKARrAAAgEAjWAmI0tJSLV68WKWlpX4vxVe8Dl14HbrwOnThdejC69ClGF+HvE+wBQAAhY2dFQAAEGgEKwAAINAIVgAAQKARrAAAgEAjWPHRfffdpwsvvFCDBg3SiSeeaOs5N954o0KhUK+v2tpabxfqsWxeB8MwtGTJEo0YMUIDBw7UpZdeqnfeecfbhXrs0KFDuuGGGxQOhxUOh3XDDTfos88+S/ucQng//OpXv1JVVZXKyso0YcIE/fGPf0x7/caNGzVhwgSVlZXpq1/9qn7961/naKXecvI6vPLKK0n/7qFQSO+9914OV+y+V199VTNnztSIESMUCoX0/PPPZ3xOIb4fnL4Ohfp+6IlgxUft7e364Q9/qLlz5zp63vTp09Xc3Nz99eKLL3q0wtzI5nV48MEH9fDDD2vFihXaunWrIpGI6urqumdF5aNrr71WO3bs0Nq1a7V27Vrt2LFDN9xwQ8bn5fP74fe//73mz5+vu+++W9u3b9fFF1+sGTNmaO/evZbXNzU16Tvf+Y4uvvhibd++XXfddZd+/OMf65lnnsnxyt3l9HUwvf/++73+7ceOHZujFXvjyJEjOuecc7RixQpb1xfq+8Hp62AqtPdDLwZ89/jjjxvhcNjWtbNnzzauuOIKT9fjF7uvQ2dnpxGJRIwHHnig+7Fjx44Z4XDY+PWvf+3hCr3T2NhoSDJef/317scaGhoMScZ7772X8nn5/n6YOHGiccstt/R67Gtf+5px5513Wl5/xx13GF/72td6PXbzzTcbtbW1nq0xF5y+Dhs2bDAkGYcOHcrB6vwhyXjuuefSXlOo74ee7LwOxfB+YGclD73yyisaNmyYzjjjDM2ZM0f79+/3e0k51dTUpJaWFk2bNq37sdLSUl1yySV67bXXfFxZ9hoa
GhQOh3XBBRd0P1ZbW6twOJzxb8rX90N7e7vefPPNXv+OkjRt2rSUf3NDQ0PS9d/+9rf1xhtv6Pjx456t1UvZvA6mcePGKRqN6lvf+pY2bNjg5TIDqRDfD31RyO8HgpU8M2PGDP3ud7/T+vXr9dBDD2nr1q2aOnWq2tra/F5azrS0tEiShg8f3uvx4cOHd/8s37S0tGjYsGFJjw8bNizt35TP74cDBw6oo6PD0b9jS0uL5fVffvmlDhw44NlavZTN6xCNRvWf//mfeuaZZ/Tss8/qzDPP1Le+9S29+uqruVhyYBTi+yEbxfB+yPupy0GzZMkS1dfXp71m69atOu+887K6/9VXX939f9fU1Oi8887TqFGjtGbNGn3/+9/P6p5e8Pp1kKRQKNTre8Mwkh7zm93XQUr+e6TMf1O+vB/ScfrvaHW91eP5xsnrcOaZZ+rMM8/s/n7SpEnat2+ffvGLX+ib3/ymp+sMmkJ9PzhRDO8HghWXzZs3T7NmzUp7zejRo137fdFoVKNGjdKHH37o2j3d4OXrEIlEJHX9V1U0Gu1+fP/+/Un/leU3u6/DW2+9pb/+9a9JP/v0008d/U1BfT9YqaysVL9+/ZJ2D9L9O0YiEcvrv/KVr2jo0KGerdVL2bwOVmpra7Vq1Sq3lxdohfh+cEuhvR8IVlxWWVmpysrKnP2+1tZW7du3r9eHdhB4+TpUVVUpEolo3bp1GjdunKSuc/+NGzdq2bJlnvzObNl9HSZNmqRYLKYtW7Zo4sSJkqQ//elPisViuvDCC23/vqC+H6wMGDBAEyZM0Lp163TVVVd1P75u3TpdccUVls+ZNGmSVq9e3euxl156Seedd5769+/v6Xq9ks3rYGX79u158e/upkJ8P7il4N4Pfmb3FruPPvrI2L59u1FfX2+ccMIJxvbt243t27cbhw8f7r7mzDPPNJ599lnDMAzj8OHDxr/+678ar732mtHU1GRs2LDBmDRpknHKKacY8Xjcrz+jz5y+DoZhGA888IARDoeNZ5991nj77beNa665xohGo3n9OkyfPt34xje+YTQ0NBgNDQ3G2WefbXz3u9/tdU2hvR+efvppo3///sZjjz1mNDY2GvPnzzcGDx5s7NmzxzAMw7jzzjuNG264ofv6P//5z8agQYOMBQsWGI2NjcZjjz1m9O/f3/jv//5vv/4EVzh9HZYvX24899xzxgcffGDs3LnTuPPOOw1JxjPPPOPXn+CKw4cPd//vX5Lx8MMPG9u3bzc++ugjwzCK5/3g9HUo1PdDTwQrPpo9e7YhKelrw4YN3ddIMh5//HHDMAzjiy++MKZNm2acfPLJRv/+/Y3TTjvNmD17trF3715//gCXOH0dDKOrfHnx4sVGJBIxSktLjW9+85vG22+/nfvFu6i1tdW47rrrjPLycqO8vNy47rrrkkoRC/H98Mtf/tIYNWqUMWDAAGP8+PHGxo0bu382e/Zs45JLLul1/SuvvGKMGzfOGDBggDF69Ghj5cqVOV6xN5y8DsuWLTNOP/10o6yszDjppJOMiy66yFizZo0Pq3aXWYKb+DV79mzDMIrn/eD0dSjU90NPIcP4WzYSAABAAFG6DAAAAo1gBQAABBrBCgAACDSCFQAAEGgEKwAAINAIVgAAQKARrAAAgEAjWAEAAIFGsAIAAAKNYAUAAAQawQoAAAg0ghUAABBo/x9iU3JjHhOd5AAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.scatter(x,y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57e1cb66-5337-43ab-ad76-acf76022625d",
   "metadata": {},
   "outputs": [],
   "source": [
    "## 创建模型\n",
    "class Linear(nn.Module): #创建所有预测模型都要继承nn.Module类\n",
    "    \n",
    "    def __init__(self):\n",
    "        # 定义模型参数\n",
    "        super().__init__()\n",
    "        self.a = nn.Parameter(torch.zeros(()))\n",
    "        self.b = nn.Parameter(torch.zeros(()))\n",
    "\n",
    "    def forward(self,x):\n",
    "        # 定义前向传播算法\n",
    "        return self.a * x + self.b\n",
    "\n",
    "    def string(self):\n",
    "        # 监控类的状态\n",
    "        return f'y = {self.a.item():.3f} * x + {self.b.item():.3f}'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "58448b8e-7ed0-474c-a551-22f9b9667ee7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
       "        0., 0., 0., 0.], grad_fn=<AddBackward0>)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = Linear()\n",
    "model(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "0319d635-3769-49b2-993c-df807b8e6d57",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Parameter containing:\n",
       " tensor(0., requires_grad=True),\n",
       " Parameter containing:\n",
       " tensor(0., requires_grad=True)]"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "list(model.parameters()) #打印模型的参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "c5dea1de-9cf0-45ac-86ee-e30f26bcbc4f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "y = 1.972 * x + 0.968\n",
      "y = 3.553 * x + 1.743\n",
      "y = 4.821 * x + 2.362\n",
      "y = 5.838 * x + 2.858\n",
      "y = 6.654 * x + 3.255\n",
      "y = 7.308 * x + 3.572\n",
      "y = 7.833 * x + 3.826\n",
      "y = 8.254 * x + 4.029\n",
      "y = 8.591 * x + 4.191\n",
      "y = 8.862 * x + 4.321\n",
      "y = 9.079 * x + 4.425\n",
      "y = 9.253 * x + 4.508\n",
      "y = 9.393 * x + 4.575\n",
      "y = 9.505 * x + 4.628\n",
      "y = 9.594 * x + 4.670\n",
      "y = 9.666 * x + 4.704\n",
      "y = 9.724 * x + 4.732\n",
      "y = 9.771 * x + 4.754\n",
      "y = 9.808 * x + 4.771\n",
      "y = 9.837 * x + 4.785\n"
     ]
    }
   ],
   "source": [
    "import torch.optim as optim\n",
    "\n",
    "learning_rate = 0.1\n",
    "model = Linear()\n",
    "optimizer = optim.SGD(model.parameters(),lr = learning_rate )\n",
    "\n",
    "for t in range(20):\n",
    "    # 清空梯度\n",
    "    optimizer.zero_grad()\n",
    "    y_pred = model(x)\n",
    "    #定义损失函数\n",
    "    loss = (y-y_pred).pow(2).mean()\n",
    "    # 计算梯度\n",
    "    loss.backward()\n",
    "    # 更新参数\n",
    "    optimizer.step()\n",
    "    print(model.string())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5915969c-2b1b-42f1-8ec7-1959c378a0da",
   "metadata": {},
   "source": [
    "## 手动实现梯度下降的参数更新（不调用 optimizer.step()）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "6ae0481f-2132-41a4-b730-494452e2407d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "y = 1.972 * x + 0.968\n",
      "y = 3.553 * x + 1.743\n",
      "y = 4.821 * x + 2.362\n",
      "y = 5.838 * x + 2.858\n",
      "y = 6.654 * x + 3.255\n",
      "y = 7.308 * x + 3.572\n",
      "y = 7.833 * x + 3.826\n",
      "y = 8.254 * x + 4.029\n",
      "y = 8.591 * x + 4.191\n",
      "y = 8.862 * x + 4.321\n",
      "y = 9.079 * x + 4.425\n",
      "y = 9.253 * x + 4.508\n",
      "y = 9.393 * x + 4.575\n",
      "y = 9.505 * x + 4.628\n",
      "y = 9.594 * x + 4.670\n",
      "y = 9.666 * x + 4.704\n",
      "y = 9.724 * x + 4.732\n",
      "y = 9.771 * x + 4.754\n",
      "y = 9.808 * x + 4.771\n",
      "y = 9.837 * x + 4.785\n"
     ]
    }
   ],
   "source": [
    "model = Linear()\n",
    "optimizer = optim.SGD(model.parameters(),lr = learning_rate )\n",
    "for t in range(20):\n",
    "    # 清空梯度\n",
    "    y_pred = model(x)\n",
    "    #定义损失函数\n",
    "    loss = (y-y_pred).pow(2).mean()\n",
    "    # 计算梯度\n",
    "    loss.backward()\n",
    "    # 更新参数\n",
    "    with torch.no_grad():\n",
    "        for parm in model.parameters():\n",
    "            parm -= learning_rate * parm.grad\n",
    "            parm.grad = torch.zeros(parm.grad.shape)\n",
    "    print(model.string())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff93e1f4-6268-45e8-9258-83137dddc2f7",
   "metadata": {
    "editable": true,
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "source": [
    "## 工程上的优化：使用小批量随机梯度下降（mini-batch SGD）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "4421f4cd-6c63-41dc-a548-9c86d836f8d5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "y = 2.787 * x + -1.898\n",
      "y = 9.181 * x + 4.941\n",
      "y = 9.939 * x + 4.854\n",
      "y = 10.027 * x + 4.877\n",
      "y = 9.878 * x + 4.768\n"
     ]
    }
   ],
   "source": [
    "# 随机梯度下降算法\n",
    "model = Linear()\n",
    "batch_size = 16\n",
    "optimizer = optim.SGD(model.parameters(),lr = learning_rate )\n",
    "for t in range(50):\n",
    "    ix = (t * batch_size) % len(x)\n",
    "    xx = x[ix:(ix+batch_size)]\n",
    "    yy = y[ix:(ix+batch_size)]\n",
    "    # 清空梯度\n",
    "    y_pred = model(xx)\n",
    "    #定义损失函数\n",
    "    loss = (yy-y_pred).pow(2).mean()\n",
    "    # 计算梯度\n",
    "    loss.backward()\n",
    "    # 更新参数\n",
    "    with torch.no_grad():\n",
    "        for parm in model.parameters():\n",
    "            parm -= learning_rate * parm.grad\n",
    "            parm.grad = torch.zeros(parm.grad.shape)\n",
    "    if t % 10 == 0:\n",
    "        print(model.string())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9bc289fb-0b99-4413-b64e-13d2b3a4a064",
   "metadata": {},
   "source": [
    "## view 和 reshape的区别"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "01591d4f-21eb-44ab-98d4-84a986f18c75",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([10])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 5])"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a = torch.tensor(range(0,10))\n",
    "print(a.shape)\n",
    "a.view(2,-1).shape #转化为(2,5)形状的张量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "b266c3bb-86e9-45c4-90d5-40c6d554f032",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([5, 2])\n",
      "error occurrence:\n",
      "```\n",
      "view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.\n",
      "```\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[0, 5, 1, 6, 2, 7, 3, 8, 4, 9]])"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "b = a.view(2,-1).T\n",
    "print(b.shape)\n",
    "# 报错代码\n",
    "try:\n",
    "    b.view(1,10)\n",
    "except Exception as e:\n",
    "    # 打印错误信息\n",
    "    print(\"error occurrence:\")\n",
    "    print(f'```\\n{str(e)}\\n```')\n",
    "b.reshape(1,10)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
